// Copyright © 2025 Bjango. All rights reserved.

#pragma once

#import <AudioToolbox/AudioToolbox.h>
#import <AVFoundation/AVFoundation.h>
#include <array>
#include <cmath>
#include <unordered_map>
#include "NoteMap.hpp"

/// How MidiProcessor classifies an incoming MIDI message.
enum class MIDINoteEventType : int {
    NoteOn,   // status 0x9n carrying a non-zero velocity
    NoteOff,  // status 0x8n, or status 0x9n carrying velocity 0
    Other,    // any other message; forwarded without retiming or velocity scaling
};

/// Retimes and re-velocities MIDI note events, then forwards every MIDI event
/// through the host-supplied MIDI output block.
///
/// The following are applied to note-on/note-off events when neither bypassed
/// nor unlicensed:
///   * a fixed delay,
///   * a seeded pseudo-random delay (reused for the matching note-off),
///   * a per-step delay,
///   * quarter/eighth/sixteenth swing,
///   * a per-step velocity multiplier.
///
/// Swing and step math assume four beats per measure and four steps (sixteenths)
/// per beat — see calculateSwingResolution and calculateStepIndex.
class MidiProcessor {
public:
    MidiProcessor() {
        // One slot per MIDI note number (0-127).
        noteOnOffsets.resize(128);
    }
    
    /// Called with the step index at the start of a buffer whenever it changes.
    typedef void (*StepInfoCallback)(void* context, int currentStepIndex);

    /// Registers the step-change callback; it is invoked from inside the render block.
    void setStepInfoCallback(StepInfoCallback callback, void* context) {
        stepInfoCallback = callback;
        stepInfoCallbackContext = context;
    }
    
    //MARK: - Configurable Parameters
    /// When bypassed, MIDI events are forwarded unchanged.
    void setBypass(bool shouldBypass) {
        bypassed = shouldBypass;
    }
    
    bool isBypassed() const {
        return bypassed;
    }
    
    /// When unlicensed, MIDI events are forwarded unchanged (same path as bypass).
    void setLicensed(bool isLicensed) {
        licensed = isLicensed;
    }
    
    /// Delay added to every note-on/note-off, in seconds.
    void setFixedDelay(float delaySeconds) {
        fixedDelaySeconds = delaySeconds;
    }
    
    /// Upper bound of the pseudo-random note-on delay, in seconds (0 disables it).
    void setRandomDelay(float delaySeconds) {
        randomDelaySeconds = delaySeconds;
    }
    
    /// Seed mixed into the pseudo-random delay hash.
    void setSeed(float seedValue) {
        seed = seedValue;
    }
    
    /// Quarter-note swing as a percentage; 50 means no swing.
    void setSwingOneFour(float swing) {
        swingOneFour = swing / 100.0f;
    }
    
    /// Eighth-note swing as a percentage; 50 means no swing.
    void setSwingOneEight(float swing) {
        swingOneEight = swing / 100.0f;
    }
    
    /// Sixteenth-note swing as a percentage; 50 means no swing.
    void setSwingOneSixteen(float swing) {
        swingOneSixteen = swing / 100.0f;
    }
    
    void setSampleRate(float rate) {
        sampleRate = rate;
    }
    
    /// Sets the extra delay (seconds) for one step; out-of-range positions are ignored.
    void setStepDelay(int stepPosition, float delaySeconds) {
        if (!isValidStepPosition(stepPosition)) { return; }
        stepDelays[stepPosition] = delaySeconds;
    }
    
    /// Sets the velocity multiplier for one step as a percentage.
    /// Out-of-range positions are ignored.
    void setStepVelocity(int stepPosition, float velocity) {
        if (!isValidStepPosition(stepPosition)) { return; }
        stepVelocities[stepPosition] = (velocity / 100.0f);
    }
    
    /// Sets how many steps the pattern cycles through (0-16); other values are ignored.
    void setStepDelayCount(int count){
        if (count < 0 || count > 16) { return; }
        
        stepDelayCount = count;
    }
    
    //MARK: - Main rendering block
    /// Builds the AU render block.
    ///
    /// For each buffer the block:
    ///   1. queries the host musical context, failing with
    ///      kAudioUnitErr_CannotDoInCurrentContext if it is unavailable;
    ///   2. forwards every MIDI event — unchanged when bypassed or unlicensed,
    ///      otherwise with note-on/note-off retimed and velocity-scaled;
    ///   3. reports the step index at the start of the buffer.
    ///
    /// The block reads and writes this object's members through an implicitly
    /// captured `this`, so the processor must outlive the block.
    AUInternalRenderBlock internalRenderBlock() {
        return ^AUAudioUnitStatus(AudioUnitRenderActionFlags *actionFlags,
                                  const AudioTimeStamp       *timestamp,
                                  AUAudioFrameCount           frameCount,
                                  NSInteger                   outputBusNumber,
                                  AudioBufferList            *outputData,
                                  const AURenderEvent        *realtimeEventListHead,
                                  AURenderPullInputBlock __unsafe_unretained pullInputBlock) {
            
            // Calling a nil block crashes; a host that supplies no musical context
            // is treated the same as one that refuses the query.
            if (!musicContextBlock ||
                !musicContextBlock(&currentTempo, nullptr, nullptr, &currentBeatPosition, nullptr, &currentMeasureDownbeatPosition)) {
                return kAudioUnitErr_CannotDoInCurrentContext;
            }

            currentBeatInMeasure = currentBeatPosition - currentMeasureDownbeatPosition;
            
            // Timing constants for this buffer.
            double secondsPerSample = 1.0 / sampleRate;
            double beatsPerSecond = currentTempo / 60.0;
            double beatsPerSample = beatsPerSecond * secondsPerSample;
            
            // Step index at the start of the buffer, for UI reporting.
            int startStepIndex = calculateStepIndex(currentBeatPosition);

            const AURenderEvent *event = realtimeEventListHead;
            while (event) {
                if (event->head.eventType == AURenderEventMIDI) {
                    const AUMIDIEvent &midi = event->MIDI;
                    
                    if (bypassed || !licensed) {
                        if (midiEventBlock) { midiEventBlock(midi.eventSampleTime, midi.cable, midi.length, midi.data); }
                    } else {
                        // Musical position of this event, from its offset within the buffer.
                        AUEventSampleTime eventTime = midi.eventSampleTime;
                        double sampleOffset = 0;
                        if (eventTime != AUEventSampleTimeImmediate) {
                            sampleOffset = (double)(eventTime - (AUEventSampleTime)timestamp->mSampleTime);
                        }
                        
                        double eventBeatPosition = currentBeatPosition + (sampleOffset * beatsPerSample);
                        double eventBeatInMeasure = eventBeatPosition - currentMeasureDownbeatPosition;
                        int eventStepIndex = calculateStepIndex(eventBeatPosition);

                        const uint8_t *data = midi.data;
                        
                        const uint8_t statusByte = data[0];
                        const uint8_t note = midi.length > 1 ? data[1] : 0;
                        uint8_t velocity = midi.length > 2 ? data[2] : 0;
                        MIDINoteEventType midiType = midiNoteEventType(statusByte, velocity);
                        
                        AUEventSampleTime newTime = midi.eventSampleTime;
                        if (midiType == MIDINoteEventType::NoteOn) {
                            float randomDelaySamples = calculateRandomDelay(note, midiType, eventBeatPosition) * sampleRate;
                            newTime = calculateDelayTime(midi.eventSampleTime, randomDelaySamples, eventStepIndex, eventBeatInMeasure);
                            
                            // Remember the random offset so the matching note-off reuses it.
                            noteOnOffsets.insert(note, randomDelaySamples);
                            velocity = calculateVelocity(velocity, eventStepIndex);
                        } else if (midiType == MIDINoteEventType::NoteOff) {
                            float noteOnRandomSamples;
                            if (noteOnOffsets.pop(note, noteOnRandomSamples)) {
                                // Offset the note-off by the same random amount as its note-on.
                                newTime = calculateDelayTime(midi.eventSampleTime, noteOnRandomSamples, eventStepIndex, eventBeatInMeasure);
                            }
                            velocity = calculateVelocity(velocity, eventStepIndex);
                        }
                        
                        uint8_t buffer[3] = { statusByte, note, velocity };
                        if (midiEventBlock) { midiEventBlock(newTime, midi.cable, midi.length, buffer); }
                    }
                }
                
                event = event->head.next;
            }
            
            reportStepInfo(startStepIndex);
            
            return noErr;
        };
    }
    
    //MARK: - Callback blocks
    void setMIDIEventBlock(AUMIDIOutputEventBlock midiBlock) {
        midiEventBlock = midiBlock;
    }
    
    void setMusicContextBlock(AUHostMusicalContextBlock musicBlock) {
        musicContextBlock = musicBlock;
    }
    
    void setTransportContextBlock(AUHostTransportStateBlock transportBlock) {
        transportContextBlock = transportBlock;
    }
private:
    /// True when `stepPosition` indexes a valid slot of the per-step arrays.
    bool isValidStepPosition(int stepPosition) const {
        return stepPosition >= 0 && stepPosition < static_cast<int>(stepDelays.size());
    }

    /// Returns `eventTime` shifted by the sum of the offsets below:
    ///   * fixed delay,
    ///   * `randomDelaySamples`,
    ///   * the three swing offsets,
    ///   * the delay for `stepIndex`.
    AUEventSampleTime calculateDelayTime(AUEventSampleTime eventTime, float randomDelaySamples, int stepIndex, double beatInMeasure) {
        const double fixedDelaySamples = static_cast<double>(fixedDelaySeconds) * sampleRate;
        const double stepDelaySamples = static_cast<double>(stepDelays[stepIndex]) * sampleRate;
        
        const double oneFourOffsetSamples =
            beatInMeasureDeltaToSamples(beatInMeasure, calculateSwingResolution(4, swingOneFour, beatInMeasure));
        const double oneEightOffsetSamples =
            beatInMeasureDeltaToSamples(beatInMeasure, calculateSwingResolution(8, swingOneEight, beatInMeasure));
        const double oneSixteenOffsetSamples =
            beatInMeasureDeltaToSamples(beatInMeasure, calculateSwingResolution(16, swingOneSixteen, beatInMeasure));
        
        const double totalOffsetSamples = fixedDelaySamples + static_cast<double>(randomDelaySamples)
            + oneFourOffsetSamples + oneEightOffsetSamples + oneSixteenOffsetSamples + stepDelaySamples;
        
        // Sum the (small) offsets in double, then add them to the integer sample time.
        // Adding a float directly to the int64 would convert the absolute sample time
        // to float, losing whole-sample precision beyond 2^24 samples.
        return eventTime + static_cast<AUEventSampleTime>(std::llround(totalOffsetSamples));
    }
    
    /// Scales `currentVelocity` by the step's multiplier, clamped to the 7-bit MIDI range.
    uint8_t calculateVelocity(uint8_t currentVelocity, int stepIndex) {
        const float scaled = currentVelocity * stepVelocities[stepIndex];
        // A data byte above 127 would set the MIDI status bit; negatives (and NaN)
        // would be undefined when cast to uint8_t.
        if (!(scaled > 0.0f)) { return 0; }
        if (scaled >= 127.0f) { return 127; }
        return static_cast<uint8_t>(scaled);
    }
    
    /// Deterministic pseudo-random delay in seconds, within [0, randomDelaySeconds).
    /// Derived from a sine hash of the beat position, note number and seed.
    float calculateRandomDelay(uint8_t note, MIDINoteEventType midiType, double beatPosition) {
        if (randomDelaySeconds == 0 || midiType == MIDINoteEventType::Other) { return 0; }
        
        float notePitch = static_cast<float>(note);
        float random = sin((beatPosition + notePitch) * 281.53 + seed * 57.15) * 131.13;
        float randomFract = random - floor(random);
        return randomDelaySeconds * randomFract;
    }

    /// Index of the sixteenth-note step (four per beat) containing `beatPosition`,
    /// wrapped to the pattern length.
    int calculateStepIndex(double beatPosition) {
        if (stepDelayCount <= 1) return 0;

        // Tolerance so positions a hair before a step boundary count as on it.
        constexpr double kEpsilon = 1e-4;
        double sixteenthIndex = beatPosition * 4.0;
        if (sixteenthIndex > -kEpsilon && sixteenthIndex < 0.0) {
            sixteenthIndex = 0.0;
        }

        long idx = static_cast<long>(std::floor(sixteenthIndex + kEpsilon));
        idx = idx % stepDelayCount;
        if (idx < 0) {
            idx += stepDelayCount;
        }

        return static_cast<int>(idx);
    }
    
    /// Maps `beatInMeasure` onto its swung position.
    ///
    /// Steps are grouped in pairs; `swingAmount` (0.5 = straight) is the first
    /// step's share of each pair's duration. Assumes four beats per measure.
    double calculateSwingResolution(int stepsPerMeasure, float swingAmount, double beatInMeasure) {
        if (swingAmount == 0.5f || stepsPerMeasure <= 0) { return beatInMeasure; }

        const double stepsPerBeat = stepsPerMeasure / 4.0;
        const double stepDuration = 1.0 / stepsPerBeat;  // beats per step

        // Position in fractional step space.
        const double stepPosition = beatInMeasure * stepsPerBeat;
        const double baseStep = std::floor(stepPosition);
        const double stepFraction = stepPosition - baseStep;

        // Floor-based pairing keeps negative positions in the correct pair.
        const double pairStart = 2.0 * std::floor(baseStep / 2.0);
        const bool isFirstInPair = (baseStep == pairStart);

        // Durations (in beats) of the two steps within the swung pair.
        const double pairDuration = 2.0 * stepDuration;
        const double firstDuration = pairDuration * swingAmount;
        const double secondDuration = pairDuration * (1.0 - swingAmount);

        const double pairStartBeat = pairStart * stepDuration;
        return isFirstInPair
            ? pairStartBeat + stepFraction * firstDuration
            : pairStartBeat + firstDuration + stepFraction * secondDuration;
    }
    
    /// Converts a beat-position difference into a rounded sample offset at the current tempo.
    double beatInMeasureDeltaToSamples(double oldBeatInMeasure, double newBeatInMeasure) {
        // A zero/negative tempo would produce inf/NaN offsets.
        if (!(currentTempo > 0)) { return 0; }
        
        const double beatDelta = newBeatInMeasure - oldBeatInMeasure;
        const double secondsPerBeat = 60.0 / currentTempo;
        return std::round(beatDelta * secondsPerBeat * sampleRate);
    }
    
    /// Invokes the step callback only when the step index has changed since the last report.
    void reportStepInfo(int stepIndex) {
        if (stepIndex == lastReportedStepIndex) { return; }
        
        if (stepInfoCallback) {
            lastReportedStepIndex = stepIndex;
            stepInfoCallback(stepInfoCallbackContext, stepIndex);
        }
    }
    
    /// Classifies a message by its status nibble; note-on with velocity 0 is a note-off.
    MIDINoteEventType midiNoteEventType(uint8_t statusByte, uint8_t velocity) {
        uint8_t status = statusByte & 0xF0;

        switch (status) {
            case 0x90:
                return (velocity > 0) ? MIDINoteEventType::NoteOn : MIDINoteEventType::NoteOff;
            case 0x80:
                return MIDINoteEventType::NoteOff;
            default:
                return MIDINoteEventType::Other;
        }
    }
    
    AUMIDIOutputEventBlock midiEventBlock;
    
    // Musical-context variables, refreshed by the render block each buffer.
    AUHostMusicalContextBlock musicContextBlock;
    double currentTempo = 0;
    double currentBeatPosition = 0;
    double currentMeasureDownbeatPosition = 0;
    double currentBeatInMeasure = 0;
    
    AUHostTransportStateBlock transportContextBlock;
    
    float sampleRate = 44100;
    
    bool licensed = true;
    
    // Configurable params.
    bool bypassed = false;
    float swingOneFour = 0.5f;     // fractions; 0.5 = straight
    float swingOneEight = 0.5f;
    float swingOneSixteen = 0.5f;
    
    float randomDelaySeconds = 0.0f;
    float fixedDelaySeconds = 0.0f;
    float seed = 32.0f;
    
    std::array<float, 16> stepDelays = {0.0f};      // seconds per step
    // NOTE(review): multipliers default to 0; presumably set via parameters before
    // rendering — otherwise every note-on would be emitted with velocity 0.
    std::array<float, 16> stepVelocities = {0.0f};
    int stepDelayCount = 16;
    int lastReportedStepIndex = -1;
    
    // Note-on random offsets, stored so the matching note-off uses the same one.
    // This avoids randomness causing a note-off to precede its note-on.
    NoteMap noteOnOffsets;
    
    StepInfoCallback stepInfoCallback = nullptr;
    void* stepInfoCallbackContext = nullptr;
};
